package com.hanxiaozhang.threadbase1ndedition.no10threadpool;

import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;

/**
 * Demonstrates {@link ForkJoinPool} with two divide-and-conquer tasks over a
 * shared array of random numbers: one that prints the sum of each leaf range
 * ({@link SubPrintTask}, no result) and one that returns the total sum
 * ({@link SubSumTask}).
 *
 * @Author:hanxinghua
 * @Date: 2021/11/25
 */
public class No13ForkJoinPool {

    /**
     * Input data: one million ints, filled by the static initializer below.
     */
    private static final int[] NUMS = new int[1000000];

    /**
     * Largest range length a task handles directly; longer ranges are split in half.
     */
    private static final int GROUP_NUM = 50000;

    /**
     * Random source used to fill {@link #NUMS}.
     */
    static Random random = new Random();

    // Fill the array with values in [0, 100) and print the reference total.
    static {
        for (int i = 0; i < NUMS.length; i++) {
            NUMS[i] = random.nextInt(100);
        }
        System.out.println("---- 数组中数据总和: " + Arrays.stream(NUMS).sum());
    }


    /**
     * Runs both demos in sequence, each on its own pool.
     *
     * @param args unused
     * @throws IOException          never thrown; kept so the signature stays unchanged
     * @throws InterruptedException never thrown; kept so the signature stays unchanged
     */
    public static void main(String[] args) throws IOException, InterruptedException {

        // Demo 1: print the sum of every leaf range.
        System.out.println("---- 举例数据分段打印");
        ForkJoinPool forkJoinPool1 = new ForkJoinPool();
        try {
            // invoke() blocks until the whole task tree has finished, so no
            // timed wait is needed (awaitTermination before shutdown() would
            // just sleep for the full timeout).
            forkJoinPool1.invoke(new SubPrintTask(0, NUMS.length));
        } finally {
            forkJoinPool1.shutdown();
        }

        // Demo 2: compute the total sum.
        System.out.println("---- 举例数据分段求和");
        ForkJoinPool forkJoinPool2 = new ForkJoinPool();
        long result;
        try {
            result = forkJoinPool2.invoke(new SubSumTask(0, NUMS.length));
        } finally {
            // Release the worker threads once the result is available.
            forkJoinPool2.shutdown();
        }
        System.out.println("---- 数组中数据总和: " + result);
    }

    /**
     * Sums {@code NUMS[from]} .. {@code NUMS[to - 1]}.
     *
     * @param from first index, inclusive
     * @param to   last index, exclusive
     * @return the sum of the range; {@code 0} when the range is empty
     */
    private static long sumRange(int from, int to) {
        long sum = 0L;
        for (int i = from; i < to; i++) {
            sum += NUMS[i];
        }
        return sum;
    }

    /**
     * Prints the sum of each leaf range of {@link #NUMS}.
     * <p>
     * Task without a return value. The task completes only after all of its
     * sub-ranges have been printed.
     */
    static class SubPrintTask extends RecursiveAction {

        private final int start;
        private final int end;

        /**
         * @param s first index of the range, inclusive
         * @param e last index of the range, exclusive
         */
        public SubPrintTask(int s, int e) {
            start = s;
            end = e;
        }

        @Override
        protected void compute() {
            // Small enough: sum and print this range directly.
            if (end - start <= GROUP_NUM) {
                System.out.println("from:" + start + " to:" + end + " = " + sumRange(start, end));
                return;
            }

            // Too large: split in half. invokeAll waits for both halves, so a
            // caller waiting on this task also waits for every descendant.
            int middle = start + (end - start) / 2;
            invokeAll(new SubPrintTask(start, middle), new SubPrintTask(middle, end));
        }
    }

    /**
     * Computes the sum of a range of {@link #NUMS}.
     * <p>
     * Task with a return value.
     */
    static class SubSumTask extends RecursiveTask<Long> {

        private final int start;
        private final int end;

        /**
         * @param s first index of the range, inclusive
         * @param e last index of the range, exclusive
         */
        public SubSumTask(int s, int e) {
            start = s;
            end = e;
        }

        /**
         * @return the sum of {@code NUMS[start]} .. {@code NUMS[end - 1]}
         */
        @Override
        protected Long compute() {
            // Small enough: sum this range directly.
            if (end - start <= GROUP_NUM) {
                return sumRange(start, end);
            }

            // Too large: split in half. Fork the left half, compute the right
            // half on the current worker instead of idling, then join the left.
            int middle = start + (end - start) / 2;
            SubSumTask left = new SubSumTask(start, middle);
            SubSumTask right = new SubSumTask(middle, end);
            left.fork();
            long rightSum = right.compute();
            return left.join() + rightSum;
        }

    }

}
